Skip to content

Guard Windows ROCm torchao override skip#6837

Merged
Imagineer99 merged 22 commits into
unslothai:mainfrom
InfoSage05:issue-6833-windows-rocm-torchao
Jul 3, 2026
Merged

Guard Windows ROCm torchao override skip#6837
Imagineer99 merged 22 commits into
unslothai:mainfrom
InfoSage05:issue-6833-windows-rocm-torchao

Conversation

@InfoSage05

Copy link
Copy Markdown
Contributor

What changed

This PR hardens the Windows ROCm torchao skip path in studio/install_python_stack.py.

Previously, the installer skipped torchao only when _rocm_windows_torch_installed had been set earlier in the install flow. This change adds a direct runtime probe of the installed torch build and skips the torchao override whenever the target venv already contains a Windows ROCm torch wheel.

Why it changed

Issue #6833 reports that on Windows ROCm, torchao==0.17.0 crashes on import because the ROCm Windows torch build does not expose torch.ops._c10d_functional.all_gather_into_tensor.

The repo already had the main mitigation in place:

  • Studio runtime stubs torchao on Windows ROCm so transformers and peft imports do not crash.
  • install_python_stack.py skips installing torchao when the ROCm-install path marks _rocm_windows_torch_installed.

The remaining gap was that the skip depended on that earlier flag being present. If the flag was missed or stale but the environment already contained a ROCm torch wheel, the installer could still install torchao and ship a package that crashes on import.

Root cause

The issue is not general ROCm failure and not GPU detection failure. The root cause is that Windows ROCm torch can be present in the venv while the torchao override decision still relies on installer state rather than the actual installed torch runtime.

This PR closes that gap by checking the installed torch runtime directly before applying the torchao override.

User impact

  • Windows ROCm installs no longer depend solely on _rocm_windows_torch_installed to avoid installing incompatible torchao.
  • If the venv already has a ROCm torch build, the installer now skips torchao consistently.
  • This reduces the chance of Studio falling into the import-crash path described in #6833.

Validation

Focused tests:

  • pytest -q tests/studio/install/test_rocm_support.py -k 'RocmTorchInstalledEnvVar or WindowsRocmTorchaoGuard'

Broader installer coverage:

  • pytest -q tests/studio/install/test_rocm_support.py tests/studio/install/test_pr5940_followups.py

Results:

  • Focused slice: 8 passed
  • Broader slice: 378 passed, 4 skipped

Ayushman Paul added 3 commits July 2, 2026 12:39
When doing full finetuning (FFT) of a bfloat16 model, the fp16/bf16
mismatch validation fires before the corrective logic runs, causing a
misleading error even though the code would properly handle it downstream.
Skip the validation when full_finetuning is active.

Fixes unslothai#6731
…idation

Instead of entirely skipping validation (which could let mismatches
through when mixed_precision_dtype is float32), auto-correct explicit
fp16/bf16 settings that conflict with the model's dtype for FFT. This
way the existing validation still catches real mismatches for non-FFT
cases, and the corrective logic below handles the normalized settings.

Fixes the issue raised in Codex review of PR unslothai#6813.
Detect installed ROCm torch directly before applying the torchao override so Windows ROCm environments never install the crashing torchao package even if the earlier ROCm-installed flag is missing.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a check to detect Windows ROCm PyTorch builds to safely skip torchao installation, along with corresponding unit tests. It also updates the TRL RL trainer patching logic to automatically adjust precision flags (use_fp16/use_bf16) during full finetuning to match the model's dtype. The review feedback suggests making the ROCm torch probe more robust against stray stdout noise by checking the last non-empty line, and ensuring that the underlying args.fp16 and args.bf16 attributes are updated alongside the local variables to maintain consistency.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread studio/install_python_stack.py Outdated
Comment thread unsloth/models/rl.py
InfoSage05 and others added 7 commits July 3, 2026 14:20
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Tolerate stray stdout noise when probing Windows ROCm torch installs by checking the last non-empty output line, matching the existing torch version probe behavior. Also keep args.fp16 and args.bf16 synchronized with the full-finetuning precision auto-corrections in the RL trainer patch so downstream eval settings see a consistent TrainingArguments state.
@Imagineer99

Copy link
Copy Markdown
Collaborator

Main installer change looks correct.

I have one question, why is the unsloth/models/rl.py precision flag change included? The PR description is focused on Windows ROCm/torchao, was it intentional to include?

Could we also make the new guard test a bit more behavioral rather than checking the exact source string? For example, patch _installed_torch_is_windows_rocm() to return true and assert that the torchao install path is skipped. I think that would cover the intent more directly and be less brittle.

@InfoSage05

Copy link
Copy Markdown
Contributor Author

@Imagineer99 Yes! the unsloth/models/rl.py change was intentional, but it is unrelated to the Windows ROCm / torchao guard. It came from separate review feedback on the same branch, so I agree it makes this PR less focused. I’ll remove that RL precision-flag change from this PR and keep this branch scoped to the Windows ROCm installer behavior only.

Good point on the test as well. I’ll switch the new guard test to a behavioral one by patching _installed_torch_is_windows_rocm() to return True and asserting that the torchao install path is skipped, instead of asserting on the exact source string. That should better cover the intended behavior and avoid brittle source-level coupling.

@InfoSage05 InfoSage05 marked this pull request as ready for review July 3, 2026 12:51
Patch imported MLXTrainer and MLXTrainingConfig objects to preserve the expected dataclass field ordering and to provide a _train_dataset_for_batches fallback when older trainers or test doubles only expose train_dataset. Also add focused worker tests covering both compatibility paths.
@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Repo admins can enable using credits for code reviews in their settings.

@Imagineer99

Copy link
Copy Markdown
Collaborator
image

I pushed the cleanup we discussed.

The PR is now scoped to the Windows ROCm torchao guard: the installer skips torchao when either the ROCm install marker is set or the installed torch runtime probes as Windows ROCm. I also removed the unrelated RL/MLX changes and switched the test coverage to exercise the behavioral skip path instead of relying on source-string assertions.

CI is green on the latest head, merging.

@Imagineer99 Imagineer99 merged commit c356427 into unslothai:main Jul 3, 2026
36 of 37 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants